"""
Major library for graph dictionary learning algorithms

"""
# pylint: disable=anomalous-backslash-in-string
# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
# pylint: disable=no-else-return

from time import time
import os
import pickle
from collections import defaultdict
from multiprocessing import Process, Manager
from ctypes import c_char_p
import heapq
from operator import itemgetter
import shutil
import itertools

import numpy as np
from sklearn import linear_model
from sklearn.utils.extmath import randomized_svd
import scipy
import scipy.sparse as sp
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

# from utils import load_train_conf, fast_numpy_slicing
from model import MLP
from data_utils import scipy_coo_to_torch_sparse, InfiniteLooper, my_lil_matrix
from utils import torch_batch_matrix_list_mul_matrix, torch_batch_matrix_mul_matrix_list, CPU_Unpickler

# Module-wide switches: DEBUG adds per-batch progress output and extra timing
# statistics to verbose runs; TB enables the TensorBoard shim right below.
DEBUG, TB = True, False

if TB:
    # Point tf.io.gfile at tensorboard's stub implementation (presumably to keep
    # an installed TensorFlow from clashing with torch's SummaryWriter).
    import tensorflow as tf
    import tensorboard as tb
    tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

# The theta naming convention outside this class differs from the one used inside it. Outside
# this class, the rows of theta hold the indices for each node; inside this class, the columns
# of theta do. Therefore, ThetaT is used to denote the row-wise case.
class SDMP: # pylint:disable=too-many-instance-attributes
    """
    Main class to learn the sparse decomposition of message passing.

    The target matrix ``Omega`` (one row per node) is approximated as
    ``ThetaT @ h(X)``, where ``h`` is an MLP applied to the node features ``X``
    and ``ThetaT`` is a sparse, non-negative coefficient matrix of shape
    ``(data_size, data_size)`` kept as a ``scipy.sparse.lil_matrix``. Training
    alternates between re-solving rows of ``ThetaT`` for a mini-batch of nodes
    with a positive LassoLars (:meth:`update_Theta`) and one gradient step on
    ``h`` (:meth:`update_h`).
    """
    # @profile
    def __init__(self, # pylint: disable=too-many-arguments, dangerous-default-value
                 X,
                 Omega,
                 theta_cand,
                 h_init_theta,
                 train_conf,
                 device="cpu",
                 verbose=True):
        """
        Args:
            X: node feature matrix, indexed as ``X[idx, :]`` and converted with
                ``torch.from_numpy`` (so presumably a dense numpy array).
            Omega: target matrix indexed by node; its column count sets the
                output dimension of ``h``.
            theta_cand: sparse candidate pattern; row ``i`` lists the nodes that
                may appear in node ``i``'s decomposition. Only used when
                ``train_conf["theta_cand_mode"]`` is "sparse", "mixed" or "dense".
            h_init_theta: sparse initial ThetaT used to initialize Theta and to
                pretrain ``h``; may be None.
            train_conf: dict of hyper-parameters (the keys read below).
            device: torch device for ``h`` and all tensors.
            verbose: print progress information.
        """
        self.X = X
        self.Omega = Omega
        self.theta_cand = theta_cand
        self.h_init_theta = h_init_theta
        self.train_conf = train_conf
        self.epoch = train_conf["epoch"]
        self.batch_size = train_conf["batch_size"]
        self.eval_step = train_conf["eval_step"]
        self.eval_batch_size = train_conf["eval_batch_size"]
        self.split_seed = train_conf["split_seed"]
        self.inductive_train = train_conf["inductive_train"]
        self.inductive_train_ratio = train_conf["inductive_train_ratio"]
        self.partial_test = train_conf["partial_test"]
        self.partial_test_ratio = train_conf["partial_test_ratio"]
        self.theta_n_nonzero = train_conf["theta_n_nonzero"]
        self.theta_cand_mode = train_conf["theta_cand_mode"]
        self.h_init_epoch = train_conf["h_init_epoch"]
        self.h_hidden = train_conf["h_hidden"]
        self.h_loop_cnt = train_conf["h_loop_cnt"]
        self.h_lr = train_conf["h_lr"]
        self.h_l2 = train_conf["h_l2"]
        self.h_dropout = train_conf["h_dropout"]
        self.h_extra_sample = train_conf["h_extra_sample"]
        self.h_extra_sample_size = train_conf["h_extra_sample_size"]
        self.device = device
        self.verbose = verbose

        # per-node LassoLars failures caught in update_Theta
        self.DEBUG_info = []

        # initialize parameters
        self.log = defaultdict(list)
        self.data_size, self.feature_dim = self.X.shape
        _, self.GNN_encoder_dim = self.Omega.shape
        self.global_iter_cnt = None
        self._iter_cnt = None
        if self.eval_batch_size == -1:
            self.eval_batch_size = self.data_size
        self._sample_idx_full, self._sample_idx_ind, self._max_iter_full,\
            self._max_iter_ind, self.partial_idx_test = self._init_sampler()

        self.torch_features = None
        self.ThetaT, self.H = None, None # ThetaT consistent with paper notation
        # caches filled by efficient_prepare() for node-wise inference
        self.ThetaT_rows, self.ThetaT_data, self.torch_X = None, None, None
        ## initialize h function
        self.h = MLP(self.feature_dim, self.h_hidden, self.GNN_encoder_dim).to(self.device)
        ## initialize optimizer
        self.h_opt = torch.optim.Adam(self.h.parameters(),
                                      lr=self.h_lr,
                                      weight_decay=self.h_l2)
        self.h_loss_func = nn.MSELoss()

        # precomputing: turn the candidate matrix into one index array per node
        if self.theta_cand_mode in ['sparse', 'mixed', 'dense']:
            assert self.theta_cand is not None
            self.theta_cand = self._init_theta_cand_csr(self.theta_cand)

    def _init_theta_cand_csr(self, csr_cand):
        """Return, for every row of the sparse ``csr_cand``, the array of its column indices."""
        csr_cand = csr_cand.tocsr()  # no-op for CSR input; converts other sparse formats
        indptr = csr_cand.indptr
        indices = csr_cand.indices
        return [indices[indptr[i]:indptr[i+1]] for i in range(csr_cand.shape[0])]

    def _init_sampler(self):
        """
        Build the seeded node splits.

        Returns:
            (sample_idx_full, sample_idx_ind, max_iter_full, max_iter_ind, partial_idx_test):
            a seeded permutation of all nodes, its first ``inductive_train_ratio``
            share (inductive training nodes), the batch counts for both, and its
            last ``partial_test_ratio`` share (partial-test nodes).
        """
        sample_idx_full = np.arange(self.data_size)
        # for random selection of a proportion of nodes
        cnt_ind_train = int(np.ceil(self.inductive_train_ratio * self.data_size))
        cnt_partial_test = int(np.ceil(self.partial_test_ratio * self.data_size))
        if self.inductive_train and self.partial_test:
            assert cnt_ind_train + cnt_partial_test <= self.data_size
        np.random.seed(self.split_seed)
        np.random.shuffle(sample_idx_full)
        # Copies, not views: otherwise shuffling sample_idx_full later would
        # silently change which nodes are in the train / test splits.
        sample_idx_ind = sample_idx_full[:cnt_ind_train].copy()
        # Explicit start index: ``[-0:]`` would select every node when the count is 0.
        test_start = max(self.data_size - cnt_partial_test, 0)
        partial_idx_test = sample_idx_full[test_start:].copy()

        max_iter_full = int(np.ceil(self.data_size / self.batch_size))
        max_iter_ind = int(np.ceil(len(sample_idx_ind) / self.batch_size))
        return sample_idx_full, sample_idx_ind, max_iter_full, max_iter_ind, partial_idx_test

    def _sampler(self, cand_idx, max_iter):
        """
        Yield ``max_iter`` consecutive mini-batches of a random permutation of ``cand_idx``.

        The permutation is applied to a private copy, so the caller's array (for
        example a stored split) is never reordered in place. The current batch
        number is exposed as ``self._iter_cnt``.
        """
        shuffled = np.array(cand_idx, copy=True)
        np.random.shuffle(shuffled)
        local_data_size = len(shuffled)
        for self._iter_cnt in range(max_iter):
            start = self._iter_cnt * self.batch_size
            yield shuffled[start:min(start + self.batch_size, local_data_size)]

    def _init_Theta(self):
        """
        Build the initial ThetaT as a ``(data_size, data_size)`` lil_matrix.

        Rows of ``self.h_init_theta`` are copied straight into the lil storage,
        which is faster than ``tolil()``. An all-zero matrix is returned when no
        initial Theta was provided.

        Raises:
            ValueError: if ``h_init_theta`` does not have ``data_size`` rows.
        """
        ThetaT = sp.lil_matrix((self.data_size, self.data_size), dtype='float')
        if self.h_init_theta is None:
            return ThetaT
        csr = self.h_init_theta.tocsr()
        if not csr.has_sorted_indices:
            # lil rows must list their column indices in increasing order
            csr = csr.sorted_indices()
        num_rows = csr.shape[0]
        if num_rows != self.data_size:
            raise ValueError(f"h_init_theta has {num_rows} rows, expected {self.data_size}.")
        indptr, indices, data = csr.indptr, csr.indices, csr.data
        # Fill the object arrays scipy already allocated. Building them with
        # np.array(list_of_lists) fails for ragged rows on NumPy >= 1.24 and
        # produces a 2-D int array (not lists) when all rows have equal length.
        for i in range(num_rows):
            start, end = indptr[i], indptr[i + 1]
            ThetaT.rows[i] = indices[start:end].tolist()
            ThetaT.data[i] = data[start:end].tolist()
        return ThetaT

    def _forward_h(self, idx, eval_batch_size=64, show_progress=False):
        """
        Evaluate ``self.h`` on ``X[idx]`` in chunks of ``eval_batch_size`` rows.

        Unlike :meth:`get_torch_H`, the module's train/eval mode is left untouched,
        so it can be used inside a training step. Returns a tensor with one row per
        entry of ``idx`` (a ``(0, GNN_encoder_dim)`` tensor when ``idx`` is empty).
        """
        tic = time()
        n_idx = len(idx)
        max_batch = int(np.ceil(n_idx / eval_batch_size))
        list_H = []
        for bb in range(max_batch):
            cur_idx = idx[bb*eval_batch_size:min((bb+1)*eval_batch_size, n_idx)]
            cur_X = torch.from_numpy(self.X[cur_idx, :]).float().to(self.device)
            list_H.append(self.h(cur_X)) # pylint: disable=not-callable
            if show_progress:
                cur_time = time() - tic
                ETA = cur_time / (bb+1) * (max_batch - bb - 1)
                print(f"\r{bb+1}/{max_batch} finished in {cur_time:.1f}s, "
                      f"ETA {ETA:.1f}",
                      end="",
                      flush=True)
        if not list_H:
            # torch.cat([]) raises; an empty batch simply has no rows
            return torch.zeros((0, self.GNN_encoder_dim), device=self.device)
        return torch.cat(list_H, dim=0)

    # @profile
    def get_torch_H(self, idx=None, eval_batch_size=64, local_verbose=True):
        """
        Switch ``h`` to eval mode and return ``h(X[idx])`` (all nodes when ``idx`` is None).

        Progress is printed when ``verbose``, ``DEBUG`` and ``local_verbose`` are all set.
        """
        if idx is None:
            idx = np.arange(self.data_size)
        self.h.eval() # set to eval mode!
        show = self.verbose and DEBUG and local_verbose
        if show:
            print("Started inferencing H...")
        H = self._forward_h(idx, eval_batch_size=eval_batch_size, show_progress=show)
        if show:
            print()
        return H

    def eval_metrics_torch(self, target_nodes=None, local_verbose=True):
        """
        Compute the reconstruction error of ``ThetaT @ h(X)`` against ``Omega``.

        Returns:
            (regret, rel_regret): the summed squared row error over ``target_nodes``
            (all nodes by default), and the mean of each row's squared error divided
            by the squared norm of its ``Omega`` row (a zero-norm row gives inf/nan).
        """
        show = self.verbose and DEBUG and local_verbose
        if show:
            print("Evaluating metrics...")
        if target_nodes is None:
            target_nodes = np.arange(self.data_size)
        data_size = len(target_nodes)
        max_batch = int(np.ceil(data_size / self.eval_batch_size))
        list_diff, list_norm = [], []
        # evaluation only: no autograd graph is needed (not even for H)
        with torch.no_grad():
            torch_H = self.get_torch_H(eval_batch_size=self.eval_batch_size,
                                       local_verbose=local_verbose)
            if show:
                print("Started evaluating the metrics...")
            tic = time()
            for bb in range(max_batch):
                cur_idx = target_nodes[bb*self.eval_batch_size:
                                       min((bb+1)*self.eval_batch_size, data_size)]
                cur_ThetaT = scipy_coo_to_torch_sparse(
                    self.ThetaT[cur_idx, :].tocoo())\
                    .float().to(self.device)
                ThetaT_dot_H = torch.sparse.mm(cur_ThetaT, torch_H)
                cur_Omega = torch.from_numpy(self.Omega[cur_idx, :])\
                                .float().to(self.device)
                diff_row_square_sum = torch.norm(ThetaT_dot_H - cur_Omega, dim=1) ** 2
                self_norm_square_sum = torch.norm(cur_Omega, dim=1) ** 2

                list_diff.append(diff_row_square_sum.cpu().numpy())
                list_norm.append(self_norm_square_sum.cpu().numpy())
                if show:
                    cur_time = time() - tic
                    ETA = cur_time / (bb+1) * (max_batch - bb - 1)
                    print(f"\r{bb+1}/{max_batch} finished in {cur_time:.1f}s, "
                          f"ETA {ETA:.1f}",
                          end="",
                          flush=True)
            if show:
                print()
        diff = np.concatenate(list_diff)
        norm = np.concatenate(list_norm)
        regret = np.sum(diff)
        rel_regret = np.sum(diff/norm) / data_size
        return regret, rel_regret

    def eval_and_log(self, target_nodes=None):
        """Evaluate on ``target_nodes`` (all by default) and append regrets and per-row nonzero counts to the log."""
        this_regret, this_rel_regret = self.eval_metrics_torch(target_nodes=target_nodes,
                                                               local_verbose=False)

        self.log["regret"].append(this_regret)
        self.log["rel_regret"].append(this_rel_regret)
        if target_nodes is None:
            self.log["Theta_col_nonzero_cnt"].append(np.sum(self.ThetaT!=0, axis=1))
        else:
            self.log["Theta_col_nonzero_cnt"].append(np.sum(self.ThetaT[target_nodes,:]!=0, axis=1))
        self.log["Theta_col_nonzero_stat"].append(
            (np.mean(self.log["Theta_col_nonzero_cnt"][-1]),
             np.std(self.log["Theta_col_nonzero_cnt"][-1]))
        )

    def _display_stat(self, i=-1):
        """Print the regrets and nonzero statistics of log entry ``i``."""
        print(f"    Regret: {self.log['regret'][i]:.4f}"
              f" | Rel Regret: {self.log['rel_regret'][i]:.4f}"
              f" || Theta col nonzeros: {self.log['Theta_col_nonzero_stat'][i][0]:.3f} " + u"\u00B1"
              f" {self.log['Theta_col_nonzero_stat'][i][1]:.3f}")

    def _display_more_stat(self, i=-1):
        """Print the phase timings of log entry ``i``."""
        print(f"    Times: Theta all {self.log['time_Theta'][i]:.1f}"
              f" | Theta pre {self.log['time_Theta_preprocess'][i]:.1f}"
              f" | Theta lar {self.log['time_Theta_lar'][i]:.1f}"
              f" | h all {self.log['time_h'][i]:.1f}")

    def fit(self, eval_level=0): # pylint: disable=too-many-statements,too-many-branches
        """
        Run initialization, optional h pretraining, the alternating main loop and
        a final Theta recomputation; returns ``self``.

        eval_level:
            0 for fast eval (training target nodes only).
            1 for full eval (all nodes, Theta recomputed) at the beginning only.
            2 for full eval at every evaluation.
        """
        if self.verbose:
            print("Initializing...")
        self.ThetaT = self._init_Theta()
        if eval_level >= 1:
            self.post_update(target_nodes=np.arange(self.data_size))
        self.eval_and_log()
        if self.verbose:
            print("Training started...")
            self._display_stat()
            print("-"*27, end="\n\n")
        # preprocessing: optionally pretrain h against the initial Theta
        tic_pre = time()
        if self.verbose:
            print("Preprocessing...")
        if self.inductive_train:
            train_target_nodes = self._sample_idx_ind
        else:
            train_target_nodes = self._sample_idx_full
        if self.h_init_theta is None:
            if self.verbose:
                print("h_init_theta not provided, skipping pretraining h...")
        elif self.h_init_epoch <= 0:
            if self.verbose:
                print("h_init_epoch is less than 0, skipping pretrainingh ...")
        else:
            self.init_process_h(target_nodes=train_target_nodes)
        self.log['time_preprocess'].append(time()-tic_pre)
        if self.verbose:
            print(f"Preprocessing finished in {self.log['time_preprocess'][-1]:.1f}")

        if eval_level >= 1:
            self.post_update(target_nodes=np.arange(self.data_size))
            self.eval_and_log(target_nodes=np.arange(self.data_size))
        else:
            self.eval_and_log(target_nodes=train_target_nodes)
        if self.verbose:
            self._display_stat()
            print("Main loop begins...")
            print("-"*27, end="\n\n")
        # main loop
        self.global_iter_cnt = 0
        max_iter = self._max_iter_ind if self.inductive_train else self._max_iter_full
        cand_idx = self._sample_idx_ind if self.inductive_train else self._sample_idx_full
        h_extra_sampler = None
        if self.h_extra_sample:
            # a real copy (numpy ``[:]`` is only a view) so nothing downstream can
            # reorder the stored split array
            h_cand_idx = np.array(cand_idx, copy=True)
            h_extra_sampler = iter(InfiniteLooper(h_cand_idx, self.h_extra_sample_size))
        total_iter = self.epoch * max_iter
        tic_start = time()
        for e in range(self.epoch):
            for it, cur_idx in enumerate(self._sampler(cand_idx, max_iter)):
                self.global_iter_cnt += 1
                # phase Theta
                tic_Theta_start = time()
                self.update_Theta(cur_idx)
                self.log['time_Theta'].append(time()-tic_Theta_start)
                # phase h (optionally on extra nodes besides the current batch)
                tic_h_start = time()
                if self.h_extra_sample:
                    cur_h_idx = np.concatenate((cur_idx, next(h_extra_sampler)))
                else:
                    cur_h_idx = cur_idx
                self.update_h(cur_h_idx)
                self.log['time_h'].append(time()-tic_h_start)
                # evaluation and display
                if self.global_iter_cnt % self.eval_step == 0:
                    tic_eval = time()
                    if eval_level >= 2:
                        self.post_update(target_nodes=np.arange(self.data_size))
                        self.eval_and_log(target_nodes=np.arange(self.data_size))
                    else:
                        self.eval_and_log(target_nodes=train_target_nodes)
                    eval_time = time() - tic_eval
                    if self.verbose:
                        elapsed_time = time() - tic_start
                        ETA = elapsed_time /\
                            self.global_iter_cnt * (total_iter - self.global_iter_cnt)
                        print("-"*5)
                        print(f"Epoch: {e} | Iter: {it} | Global iter: {self.global_iter_cnt} "
                              f"Eval time: {eval_time:.1f} "
                              f"Elapsed time: {elapsed_time:.1f} | ETA: {ETA:.1f}")
                        self._display_stat()
                        if DEBUG: # print more info for debugging mode
                            self._display_more_stat()

        # Post update
        if self.verbose:
            print("-"*27)
            print()
            print("Started post processing...")
            if self.partial_test:
                print("Partial testing...")
        test_target_nodes = self.partial_idx_test if self.partial_test else np.arange(self.data_size)
        self.post_update(target_nodes=test_target_nodes)
        self.eval_and_log(target_nodes=test_target_nodes)
        if self.verbose:
            print(f"Training finished in {(time()-tic_start):.1f} s.")
            self._display_stat()
        return self

    # @profile
    def update_Theta(self, cur_idx, log=True):
        """
        Re-solve the rows of ThetaT for the nodes in ``cur_idx``.

        For each node ``i``, a non-negative LassoLars (at most ``theta_n_nonzero``
        iterations) fits ``Omega[i]`` with the ``h`` outputs of ``i``'s candidate
        nodes, and the nonzero coefficients replace row ``i`` of ThetaT. A node
        whose fit raises is recorded in ``self.DEBUG_info`` and keeps its old row.

        Args:
            cur_idx: node indices whose rows are recomputed.
            log: append timing information to ``self.log`` when True.
        """
        time_preprocessing, time_lar = 0.0, 0.0
        # prepare the data; H is only consumed as numpy, so skip the autograd graph
        tic = time()
        receptive_idx, local_map = self.get_batch_Theta_candidate(cur_idx)
        with torch.no_grad():
            torch_cur_H = self.get_torch_H(idx=receptive_idx, local_verbose=False)
        time_preprocessing += time() - tic
        # execute LAR and update the results
        for i, i_local_map in zip(cur_idx, local_map):
            tic_i_pre_start = time()
            i_target = self.Omega[np.array([i]), :]
            torch_i_H = torch_cur_H[i_local_map, :]
            # Gram matrix and X^T y of the LARS problem, computed on the device
            my_gram = torch.mm(torch_i_H, torch_i_H.t()).cpu().numpy()
            X = torch_i_H.cpu().numpy().transpose()
            y = i_target.transpose()
            # cast as eval_metrics_torch does, so a float64 Omega matches float32 H
            Xy = torch.mm(torch_i_H, torch.from_numpy(y).float().to(self.device))
            Xy = Xy.cpu().numpy()
            time_preprocessing += time() - tic_i_pre_start
            # execute Lars
            tic_lars_start = time()
            try:
                # ``normalize`` is not passed: sklearn ignores it when
                # fit_intercept=False, and sklearn >= 1.4 no longer accepts it.
                reg = linear_model.LassoLars(alpha=0, precompute=my_gram,
                                             max_iter=self.theta_n_nonzero,
                                             fit_intercept=False, positive=True)
                reg.fit(X=X, y=y, Xy=Xy)
                # collect the nonzero coefficients and map them back to node ids
                coef = np.asarray(reg.coef_).ravel()
                nonzero = np.flatnonzero(coef)
                cols = receptive_idx[i_local_map[nonzero]]
                # lil rows must be sorted by column; candidate order need not be
                order = np.argsort(cols, kind="stable")
                self.ThetaT.rows[i] = cols[order].tolist()
                self.ThetaT.data[i] = coef[nonzero][order].tolist()
            except Exception as e: # pylint: disable=broad-except
                # deliberate best effort: record the failing node and continue
                cur_bug_info = {}
                cur_bug_info["iter"] = self.global_iter_cnt
                cur_bug_info['node_id'] = i
                cur_bug_info['H'] = X
                self.DEBUG_info.append(cur_bug_info)
                if self.verbose:
                    print(f"!!! Bug caught with node {i} at iteration {self.global_iter_cnt}.")
                    print(str(e))
            time_lar += time() - tic_lars_start
        if log:
            self.log['time_Theta_preprocess'].append(time_preprocessing)
            self.log['time_Theta_lar'].append(time_lar)

    def update_h(self, cur_idx):
        """Take one optimizer step on ``h`` for the nodes in ``cur_idx`` and log the loss."""
        loss = self.train_h_batch(cur_idx, self.ThetaT)
        self.log['h_loss'].append(loss)

    def get_batch_Theta_candidate(self, cur_idx):
        """
        Collect the candidate nodes of a batch.

        Returns:
            (receptive_idx, aligned_local_map): the sorted union of candidate node
            ids for all nodes in ``cur_idx``, and for each node the positions of
            its own candidates inside ``receptive_idx``.

        Raises:
            ValueError: for an unknown ``theta_cand_mode``.
        """
        if self.theta_cand_mode == "full":
            receptive_idx = np.arange(self.data_size).astype(int)
            aligned_local_map = [receptive_idx for _ in range(len(cur_idx))]
        elif self.theta_cand_mode in ["sparse", "mixed", "dense"]:
            cur_candidate = [self.theta_cand[i] for i in cur_idx]
            # pivot[k]:pivot[k+1] is node k's slice of the flattened candidates
            pivot = np.cumsum([0] + [len(cand) for cand in cur_candidate])
            flat_candidate = list(itertools.chain.from_iterable(cur_candidate))
            receptive_idx, local_map = np.unique(flat_candidate, return_inverse=True)
            receptive_idx = receptive_idx.astype(int)
            aligned_local_map = [local_map[pivot[k]:pivot[k+1]] for k in range(len(pivot)-1)]
        else:
            raise ValueError("Unrecognized theta candidate mode: "
                             f"{self.theta_cand_mode!r}.")

        return receptive_idx, aligned_local_map

    def _localize_ThetaT_rows(self, ThetaT, idx):
        """
        Slice rows ``idx`` of ``ThetaT`` and renumber its columns compactly.

        Returns:
            (X_ind, local): the sorted node ids with a nonzero coefficient in those
            rows, and a torch sparse matrix of shape ``(len(idx), len(X_ind))`` on
            ``self.device`` whose column ``k`` corresponds to node ``X_ind[k]``.
        """
        cur_ThetaT = ThetaT[idx, :].tocoo()
        X_ind, local_col = np.unique(cur_ThetaT.col, return_inverse=True)
        cur_ThetaT_local = sp.coo_matrix((cur_ThetaT.data, (cur_ThetaT.row, local_col)),
                                         shape=(cur_ThetaT.shape[0], len(X_ind)))
        local = scipy_coo_to_torch_sparse(cur_ThetaT_local).float().to(self.device)
        return X_ind, local

    def train_h_batch(self, batch_idx, ThetaT):
        """
        One optimizer step on ``h`` minimizing MSE(``ThetaT[batch_idx] @ h(X)``, ``Omega[batch_idx]``).

        Returns:
            The loss value as a Python float.
        """
        self.h.train()
        # prepare samples: only the nodes referenced by this batch's rows are fed to h
        X_ind, cur_ThetaT_local = self._localize_ThetaT_rows(ThetaT, batch_idx)
        target = torch.from_numpy(self.Omega[batch_idx, :]).float().to(self.device)
        if X_ind.size == 0:
            # every row is empty: the prediction is identically zero and nothing
            # depends on h, so report the loss without a backward pass
            with torch.no_grad():
                return self.h_loss_func(torch.zeros_like(target), target).item()
        # forward in training mode (get_torch_H would switch h to eval mode)
        H = self._forward_h(X_ind)
        pred = torch.sparse.mm(cur_ThetaT_local, H)
        loss = self.h_loss_func(pred, target)
        # execute the training
        self.h_opt.zero_grad()
        loss.backward()
        self.h_opt.step()
        return loss.item()

    def init_process_h(self, target_nodes=None, local_verbose=True):
        """
        Initialize the h by training several epochs of h based on graphs.

        ``h`` is trained against the fixed ``self.h_init_theta``. ``h_init_epoch``
        >= 1 is the number of epochs (truncated to an integer); a value in (0, 1)
        trains a single pass over that random fraction of ``target_nodes``
        (default: all nodes).
        """
        init_ThetaT = self.h_init_theta.tolil()
        if target_nodes is None:
            target_nodes = self._sample_idx_full
        show = self.verbose and local_verbose
        if self.h_init_epoch < 1:
            # shuffle a real copy: a numpy ``[:]`` slice is a view and would
            # reorder the caller's (stored) index array in place
            used_idx = np.array(target_nodes, copy=True)
            cut_cnt = int(self.h_init_epoch * len(used_idx))
            np.random.shuffle(used_idx)
            used_idx = used_idx[:cut_cnt]
            used_epoch = 1
        else:
            used_idx = target_nodes
            # range() needs an int; fractional epoch counts above 1 are truncated
            used_epoch = int(self.h_init_epoch)
        max_iter = int(np.ceil(len(used_idx) / self.batch_size))
        max_batch = used_epoch * max_iter
        bb = 0
        tic = time()
        for _ in range(used_epoch):
            for cur_idx in self._sampler(used_idx, max_iter):
                cur_loss = self.train_h_batch(cur_idx, init_ThetaT)
                self.log["h_loss"].append(cur_loss)
                if show:
                    cur_time = time() - tic
                    ETA = cur_time / (bb+1) * (max_batch - bb - 1)
                    print(f"\r {bb+1}/{max_batch} finished in {cur_time:.1f}s, "
                          f"ETA {ETA:.1f}", flush=True, end="")
                bb += 1
            if show:
                print()

    def post_update(self, target_nodes=None):
        """Recompute the ThetaT rows of ``target_nodes`` (default: all nodes) with the current h."""
        if self.verbose:
            print("Post precessing... Recomputing Theta...")
        tic_start = time()
        if target_nodes is None:
            target_nodes = self._sample_idx_full
        max_iter = int(np.ceil(len(target_nodes)/self.batch_size))
        for it, cur_idx in enumerate(self._sampler(target_nodes, max_iter)):
            self.update_Theta(cur_idx, log=False)
            if self.verbose:
                cur_time = time() - tic_start
                ETA = cur_time / (it + 1) * (max_iter - it - 1)
                # fraction of batches completed so far (reaches 1.000 at the end)
                print(f'\r {(it + 1)/max_iter:.3f} finished in {cur_time:.1f} s. ETA: {ETA:.1f} s',
                      flush=True, end='')
        if self.verbose:
            print()

    def compute_ThetaT_from_h(self, target_nodes=None):
        """Rebuild ThetaT from scratch for ``target_nodes`` using the current h, then evaluate and display."""
        self.ThetaT = sp.lil_matrix((self.data_size, self.data_size), dtype='float')
        self.post_update(target_nodes=target_nodes)
        self.eval_and_log(target_nodes=target_nodes)
        self._display_stat()

    def infer_torch_node_approximal_features(self, return_H=False):
        """
        Return ``ThetaT @ h(X)`` for all nodes (also cached in ``self.torch_features``).

        When ``return_H`` is True, returns ``(features, H)`` with ``H = h(X)``.
        """
        with torch.no_grad():
            H = self.get_torch_H().detach()
            torch_ThetaT = scipy_coo_to_torch_sparse(self.ThetaT.tocoo()).float().to(self.device)
            self.torch_features = torch.sparse.mm(torch_ThetaT, H).detach()
        if return_H:
            return self.torch_features, H
        return self.torch_features

    def infer_torch_node_approximal_features_idx(self, idx):
        """Return ``ThetaT[idx] @ h(X)``, evaluating h only on the nodes those rows reference."""
        X_ind, cur_ThetaT_local = self._localize_ThetaT_rows(self.ThetaT, idx)
        H = self.get_torch_H(idx=X_ind, local_verbose=False)
        return torch.sparse.mm(cur_ThetaT_local, H)

    ########## Node wise inference module#####
    # @profile
    def efficient_node_wise_infer(self, idx):
        """
        Return ``ThetaT[idx] @ h(X)`` as a ``(1, GNN_encoder_dim)`` tensor.

        Requires :meth:`efficient_prepare` to have been called. Zeros are returned
        if the row cannot be evaluated (e.g. h rejects an empty input).
        """
        try:
            weight = self.ThetaT_data[idx].reshape(1, -1)
            H = self.efficient_get_torch_H(idx=self.ThetaT_rows[idx])
            pred = weight @ H
        except (IndexError, RuntimeError):
            # degenerate row: fall back to zeros on the same device as normal output
            pred = torch.zeros([1, self.GNN_encoder_dim], device=self.device)
        return pred

    # @profile
    def efficient_prepare(self):
        """Cache ThetaT rows/data as per-node device tensors and X as a float tensor; set h to eval mode."""
        # dtype=long so empty rows still yield valid (empty) index tensors
        self.ThetaT_rows = [torch.tensor(row, dtype=torch.long).to(self.device)
                            for row in self.ThetaT.rows]
        self.ThetaT_data = [torch.tensor(vals).float().to(self.device) for vals in self.ThetaT.data]
        self.h.eval()
        self.torch_X = torch.from_numpy(self.X).float().to(self.device)

    # @profile
    def efficient_get_torch_H(self, idx=None):
        """Return ``h`` applied to the cached feature rows ``torch_X[idx]``."""
        return self.h(self.torch_X[idx])
    ############################################

    def save(self, res_folder):
        """Pickle ThetaT (skipped in partial-test mode), the log, h's state dict and the LARS debug info."""
        if not self.partial_test:
            self.save_ThetaT(res_folder)
        with open(os.path.join(res_folder, "log.pkl"), "wb") as fout:
            pickle.dump(self.log, fout)
        with open(os.path.join(res_folder, "h_model_stat.pkl"), "wb") as fout:
            pickle.dump(self.h.state_dict(), fout)
        with open(os.path.join(res_folder, "larlasso_debug.pkl"), 'wb') as fout:
            pickle.dump(self.DEBUG_info, fout, protocol=4)
        if self.verbose:
            print(f"Results are saved in {res_folder}.")

    def save_ThetaT(self, res_folder):
        """Pickle ThetaT to ``res_folder/ThetaT.pkl`` (protocol 4)."""
        with open(os.path.join(res_folder, "ThetaT.pkl"), 'wb') as fout:
            pickle.dump(self.ThetaT, fout, protocol=4)

    def load(self, res_folder, theta_name="ThetaT.pkl", log_name="log.pkl", h_name="h_model_stat.pkl"):
        """
        Restore ThetaT (when its file exists), the log and h's weights from ``res_folder``.

        All files are unpickled, so ``res_folder`` must come from a trusted source.
        """
        if self.verbose:
            print(f"Loading results from {res_folder}.")
        theta_path = os.path.join(res_folder, theta_name)
        if os.path.exists(theta_path):
            with open(theta_path, 'rb') as fin:
                self.ThetaT = pickle.load(fin)
        with open(os.path.join(res_folder, log_name), "rb") as fin:
            self.log = pickle.load(fin)
        with open(os.path.join(res_folder, h_name), "rb") as fin:
            self.h.load_state_dict(CPU_Unpickler(fin).load())
         
def wrap_log_for_fig(log_dict, train_conf, result_folder=None):
    """
    Specifically designed for figure drawing: pair each logged relative regret
    with the cumulative wall-clock time at which it was recorded.

    The time axis starts at 0 (the evaluation before preprocessing), then adds the
    preprocessing time, then the Theta + h time of every ``eval_step`` iterations.

    Args:
        log_dict: training log with "time_Theta", "time_h", "time_preprocess"
            and "rel_regret" lists.
        train_conf: configuration dict; only "eval_step" is read.
        result_folder: if given, the (time, rel_regret) pairs are written to
            ``rel_regret_vs_time.csv`` in this folder.

    Returns:
        (x, y, tag): cumulative times, the matching rel_regret values (truncated
        to ``len(x)``), and a tag that is currently always None.
    """
    def wrap_time(foo, eval_step):
        # total time of each complete window of ``eval_step`` iterations
        display_cnt = int(len(foo) / eval_step)
        return [np.sum(foo[i*eval_step:(i+1)*eval_step]) for i in range(display_cnt)]

    tag = None
    eval_step = train_conf["eval_step"]
    theta_time = wrap_time(log_dict["time_Theta"], eval_step)
    h_time = wrap_time(log_dict["time_h"], eval_step)

    all_time = [a + b for a, b in zip(theta_time, h_time)]
    increments = [0, log_dict["time_preprocess"][0]] + all_time
    # running total in one pass (instead of re-summing every prefix)
    x = list(itertools.accumulate(increments))
    y = log_dict["rel_regret"][:len(x)]

    # saving the results
    if result_folder is not None:
        with open(os.path.join(result_folder, "rel_regret_vs_time.csv"), "w",
                  encoding="utf-8") as fout:
            for i, j in zip(x, y):
                fout.write(str(i)+","+str(j)+"\n")
        if tag is not None:
            with open(os.path.join(result_folder, "tag.txt"), "w", encoding="utf-8") as fout:
                fout.write(tag)
    return x, y, tag